import torch
from typing import Optional, TypeVar, Callable, Any

from . import _nn as _nn
from . import _onnx as _onnx

# Generic type variable, so a signature can say "returns the same type it was given".
T = TypeVar('T')

# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase:
    # TODO: no members are declared in this stub yet.
    ...

# Defined in torch/csrc/autograd/python_legacy_variable.cpp
class _LegacyVariableBase:
    # Every constructor argument is optional; `...` is the stub-file
    # convention for "has a default value that is not spelled out here".
    def __init__(
        self,
        data: Optional['torch.Tensor'] = ...,
        requires_grad: Optional[bool] = ...,
        volatile: Optional[bool] = ...,
        _grad_fn: Optional[_FunctionBase] = ...,
    ) -> None: ...

# Defined in torch/csrc/jit/python/init.cpp
# Look up a JIT operator by name. The returned callable's signature depends on
# the operator, so its parameters are left open and its result is Any (a bare
# `Callable` means the same thing implicitly and is rejected by strict checkers).
def _jit_get_operation(op_name: str) -> Callable[..., Any]: ...
# Takes a ScriptModule and returns a ScriptModule (forward refs kept as strings).
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule') -> 'torch.jit.ScriptModule': ...

# Defined in torch/csrc/Module.cpp
# --- Build / runtime configuration reports -------------------------------
def _show_config() -> str: ...
def _parallel_info() -> str: ...

# Presumably installs `doc_obj` as the docstring of `obj`; the signature only
# guarantees that the returned value has the same type as `obj`.
def _add_docstr(obj: T, doc_obj: str) -> T: ...

# --- DLPack conversion in both directions (foreign side typed as Any) ------
def _from_dlpack(data: Any) -> 'torch.Tensor': ...
def _to_dlpack(data: 'torch.Tensor') -> Any: ...

# --- Backward-compatibility warning toggles (getter, then setter) ---------
def _get_backcompat_broadcast_warn() -> bool: ...
def _set_backcompat_broadcast_warn(arg: bool) -> None: ...
def _get_backcompat_keepdim_warn() -> bool: ...
def _set_backcompat_keepdim_warn(arg: bool) -> None: ...

# --- Backend switches -----------------------------------------------------
def _is_xnnpack_enabled() -> bool: ...
def _get_mkldnn_enabled() -> bool: ...
def _set_mkldnn_enabled(arg: bool) -> None: ...

# --- Capability flags: declared here, values supplied by the extension ----
has_openmp: bool
has_mkldnn: bool
has_mkl: bool

# Defined in tools/autograd/templates/python_torch_functions.cpp
# TODO: This is technically wrong
class _VariableFunctions:
    # TODO: no members are declared in this stub yet.
    ...

# Defined in torch/csrc/jit/python/script_init.cpp
class FileCheck:
    # TODO: no members are declared in this stub yet.
    ...

# Defined in torch/csrc/Generator.cpp
class Generator:
    # Annotated attribute only; no value is assigned in this stub.
    device: 'torch.device'

    # State accessors: the state is exchanged as a Tensor.
    def get_state(self) -> 'torch.Tensor': ...
    # `'Generator'` is quoted because the class name is not bound yet while
    # its own body executes; a bare name raises NameError whenever annotations
    # are evaluated eagerly (i.e. if this file is ever imported as a .py).
    def set_state(self, _new_state: 'torch.Tensor') -> 'Generator': ...
    def manual_seed(self, seed: int) -> 'Generator': ...
    def seed(self) -> int: ...
    def initial_seed(self) -> int: ...

# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig:
    # Thread counts.
    num_calling_threads: int
    num_worker_threads: int
    # Iteration counts (warm-up and main run).
    num_warmup_iters: int
    num_iters: int
    # Where profiler output is written.
    profiler_output_path: str

class BenchmarkExecutionStats:
    # Average latency, in milliseconds per the field name.
    latency_avg_ms: float
    # Number of iterations the statistics cover.
    num_iters: int

class ThroughputBenchmark:
    # Wraps an arbitrary module object.
    def __init__(self, module: Any) -> None: ...

    # Input registration and single execution accept arbitrary arguments.
    def add_input(self, *args: Any, **kwargs: Any) -> None: ...
    def run_once(self, *args: Any, **kwargs: Any) -> Any: ...

    # Full benchmark run driven by a BenchmarkConfig.
    def benchmark(self, config: BenchmarkConfig) -> BenchmarkExecutionStats: ...

# Defined in torch/csrc/autograd/python_variable.cpp
# This will need to be code-generated.
class _TensorBase:
    # No members are declared in this stub yet.
    ...
